"""
TAPE: Assessing Few-shot Russian Language Understanding
https://arxiv.org/pdf/2210.12813.pdf

TAPE (Text Attack and Perturbation Evaluation) is a novel benchmark for few-shot
Russian language understanding evaluation that includes six complex NLU tasks, covering
multi-hop reasoning, ethical concepts, logic and commonsense knowledge.

Homepage: https://tape-benchmark.com/
"""

import numpy as np
import transformers.data.metrics.squad_metrics as squad_metrics

from lm_eval.base import Task, MultipleChoiceTask, rf
from lm_eval.metrics import mean, metric_max_over_ground_truths, f1_score_multiclass_macro


class CheGeKa(Task):
    """CheGeKa task from TAPE: free-form answer generation for a question in a topic.

    The model generates greedily until a period. Predictions are scored with
    SQuAD-style token F1 and exact match against every acceptable answer.
    Acceptable answers are stored ``;``-separated in ``doc["outputs"]``.
    """

    VERSION = 0
    DATASET_NAME = "chegeka"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def training_docs(self):
        """Return the training split as a list, materialised once and cached."""
        if not self.has_training_docs():
            return None
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def test_docs(self):
        """Return the test split as a freshly built list of raw docs."""
        if not self.has_test_docs():
            return None
        return list(self.dataset["test"])

    def doc_to_text(self, doc):
        """Fill the doc's own instruction template with its topic and question text."""
        template = doc["instruction"]
        fields = doc["inputs"]
        filled = template.format(topic=fields["topic"], text=fields["text"])
        return filled.strip()

    def doc_to_text_without_instruction(self, doc):
        """Build a fixed Russian prompt (category / question / answer) without the instruction."""
        template = 'Категория "{topic}"\nВопрос: {text}\nОтвет:'
        fields = doc["inputs"]
        filled = template.format(topic=fields["topic"], text=fields["text"])
        return filled.strip()

    def doc_to_target(self, doc):
        """Return the raw ``outputs`` string, prefixed with a space, as the few-shot target."""
        return " " + doc["outputs"]

    def construct_requests(self, doc, ctx):
        """Build the single greedy-generation request for ``doc``.

        :param doc:
            A document from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The prompt produced by fewshot_context. It contains the description,
            the few-shot examples and the question part of ``doc``.
        """
        stop_sequences = {"until": ["."]}
        return rf.greedy_until(ctx, stop_sequences)

    def process_results(self, doc, results):
        """Score one generation: best F1 and best EM over all acceptable answers.

        The per-example scores are later averaged over the dataset (see aggregation).
        """
        answers = doc["outputs"].split(";")
        prediction = results[0]

        scorers = (
            ("f1", squad_metrics.compute_f1),
            ("em", squad_metrics.compute_exact),
        )
        scores = {}
        for metric_name, scorer in scorers:
            scores[metric_name] = metric_max_over_ground_truths(scorer, prediction, answers)
        return scores

    def higher_is_better(self):
        return dict.fromkeys(("f1", "em"), True)

    def aggregation(self):
        return dict.fromkeys(("f1", "em"), mean)


class MultiQ(Task):
    """MultiQ task from TAPE: answer generation from a question and two passages.

    The prompt is built from ``question``, ``support_text`` and ``text``. The
    model generates greedily until a period. Predictions are scored with
    SQuAD-style token F1 and exact match, taking the best score over all gold
    answer segments in ``doc["outputs"]``.
    """

    VERSION = 0
    DATASET_NAME = "multiq"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def training_docs(self):
        """Return the training split as a list, materialised once and cached."""
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def test_docs(self):
        """Return the test split as a freshly built list of raw docs."""
        if self.has_test_docs():
            return list(self.dataset["test"])

    @staticmethod
    def _gold_answers(doc):
        """Return the gold answer strings of ``doc``, or ``[""]`` if it has none.

        Unlabeled docs (empty or missing ``outputs``) would otherwise crash both
        ``doc_to_target`` (indexing ``[0]``) and ``process_results``. In
        ``process_results``, ``metric_max_over_ground_truths`` presumably takes
        a ``max`` over per-answer scores, which fails on an empty list.
        """
        segments = [answer["segment"] for answer in (doc["outputs"] or [])]
        return segments or [""]

    def doc_to_text(self, doc):
        """Fill the doc's own instruction template with its question and both texts."""
        prompt = (
            doc["instruction"]
            .format(
                question=doc["inputs"]["question"],
                support_text=doc["inputs"]["support_text"],
                text=doc["inputs"]["text"],
            )
            .strip()
        )
        return prompt

    def doc_to_target(self, doc):
        """Return the first gold answer segment, prefixed with a space ("" if unlabeled)."""
        return " " + self._gold_answers(doc)[0]

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        return rf.greedy_until(ctx, {"until": ["."]})

    def process_results(self, doc, results):
        """Score one generation: best F1 and best EM over all gold answer segments."""
        gold_label_set = self._gold_answers(doc)
        pred = results[0]

        f1 = metric_max_over_ground_truths(squad_metrics.compute_f1, pred, gold_label_set)
        em = metric_max_over_ground_truths(squad_metrics.compute_exact, pred, gold_label_set)

        return {
            "f1": f1,
            "em": em,
        }

    def higher_is_better(self):
        return {
            "f1": True,
            "em": True,
        }

    def aggregation(self):
        return {
            "f1": mean,
            "em": mean,
        }


class RuWorldTree(MultipleChoiceTask):
    """ruWorldTree task from TAPE: four-option (A-D) multiple-choice QA.

    Each option letter is scored by log-likelihood as a continuation of the
    instruction-formatted prompt, and the highest-scoring letter is the
    prediction. Docs whose ``outputs`` field is empty (unlabeled) get
    ``gold == ""``.
    """

    VERSION = 0
    DATASET_NAME = "ruworldtree"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def training_docs(self):
        """Return the processed training split as a list, built once and cached."""
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(map(self._process_doc, self.dataset["train"]))
            return self._training_docs

    def validation_docs(self):
        """Return a lazy ``map`` over the processed validation split (disabled here)."""
        if self.has_validation_docs():
            return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
        """Return a lazy ``map`` over the processed test split (single-pass iterator)."""
        if self.has_test_docs():
            return map(self._process_doc, self.dataset["test"])

    def _process_doc(self, doc):
        """Add ``query``, ``choices`` and ``gold`` to ``doc`` in place and return it.

        ``gold`` is the index of ``doc["outputs"]`` in ``["A", "B", "C", "D"]``,
        or ``""`` when ``outputs`` is empty. ``list.index`` raises ``ValueError``
        if ``outputs`` is some other non-empty value.
        """
        query = doc["instruction"].format(**doc["inputs"]).strip()
        choices = list("ABCD")
        if doc["outputs"]:
            gold = choices.index(doc["outputs"])
        else:
            gold = ""

        doc["query"] = query
        doc["choices"] = choices
        doc["gold"] = gold
        return doc

    def doc_to_text(self, doc):
        return doc["query"]

    def doc_to_text_without_instruction(self, doc):
        """Build a fixed prompt listing the question and options A-D, without the instruction."""
        prompt = "{question}\nA) {option_a}\nB) {option_b}\nC) {option_c}\nD) {option_d}\nОтвет:".format(
            **doc["inputs"]
        ).strip()
        return prompt

    def doc_to_target(self, doc):
        """Return the gold option letter prefixed with a space, or " " for unlabeled docs."""
        # A second, unguarded definition of this method used to follow later in
        # the class and silently override this one. That version raised
        # TypeError on unlabeled docs (gold == ""); it has been removed.
        if isinstance(doc["gold"], int):
            gold = doc["choices"][doc["gold"]]
        else:
            gold = ""
        return " " + gold

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["query"]

    def construct_requests(self, doc, ctx):
        """Request one log-likelihood per option letter (" A", " B", ...) given ``ctx``."""
        lls = [rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]]

        return lls

    def process_results(self, doc, results):
        """Predict the highest-likelihood option and report accuracy plus the (gold, pred) pair for macro-F1."""
        gold = doc["gold"]
        pred = np.argmax(results)

        return {"acc": float(pred == gold), "f1_macro": (gold, pred)}

    def higher_is_better(self):
        return {"acc": True, "f1_macro": True}

    def aggregation(self):
        return {"acc": mean, "f1_macro": f1_score_multiclass_macro}


class RuOpenBookQA(MultipleChoiceTask):
    """ruOpenBookQA task from TAPE: four-option (A-D) multiple-choice QA.

    Every option letter is scored by log-likelihood after the
    instruction-formatted prompt, and the argmax is the prediction. Unlabeled
    docs (empty ``outputs``) carry ``gold == ""``.
    """

    VERSION = 0
    DATASET_NAME = "ruopenbookqa"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def training_docs(self):
        """Return the processed training split as a list, built once and cached."""
        if not self.has_training_docs():
            return None
        if self._training_docs is None:
            processed = map(self._process_doc, self.dataset["train"])
            self._training_docs = list(processed)
        return self._training_docs

    def validation_docs(self):
        """Return a lazy ``map`` over the processed validation split (disabled here)."""
        if not self.has_validation_docs():
            return None
        return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
        """Return a lazy ``map`` over the processed test split (single-pass iterator)."""
        if not self.has_test_docs():
            return None
        return map(self._process_doc, self.dataset["test"])

    def _process_doc(self, doc):
        """Attach ``query``, ``choices`` and ``gold`` to ``doc`` in place and return it.

        ``gold`` is the position of ``doc["outputs"]`` among the letters A-D, or
        ``""`` when ``outputs`` is empty.
        """
        formatted_query = doc["instruction"].format(**doc["inputs"]).strip()
        letters = list("ABCD")
        answer_letter = doc["outputs"]
        gold_index = letters.index(answer_letter) if answer_letter else ""

        doc["query"] = formatted_query
        doc["choices"] = letters
        doc["gold"] = gold_index
        return doc

    def doc_to_text(self, doc):
        return doc["query"]

    def doc_to_text_without_instruction(self, doc):
        """Build a fixed prompt listing the question and options A-D, without the instruction."""
        template = "{question}\nA) {option_a}\nB) {option_b}\nC) {option_c}\nD) {option_d}\nОтвет:"
        filled = template.format(**doc["inputs"])
        return filled.strip()

    def doc_to_target(self, doc):
        """Return the gold option letter prefixed with a space, or " " for unlabeled docs."""
        gold_index = doc["gold"]
        letter = doc["choices"][gold_index] if isinstance(gold_index, int) else ""
        return " " + letter

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["query"]

    def construct_requests(self, doc, ctx):
        """Request one log-likelihood per option letter (" A", " B", ...) given ``ctx``."""
        requests = []
        for letter in doc["choices"]:
            requests.append(rf.loglikelihood(ctx, " {}".format(letter))[0])
        return requests

    def process_results(self, doc, results):
        """Predict the highest-likelihood option; report accuracy and the (gold, pred) pair for macro-F1."""
        gold_index = doc["gold"]
        predicted_index = np.argmax(results)
        is_correct = predicted_index == gold_index
        return {"acc": float(is_correct), "f1_macro": (gold_index, predicted_index)}

    def higher_is_better(self):
        return dict.fromkeys(("acc", "f1_macro"), True)

    def aggregation(self):
        metrics = {}
        metrics["acc"] = mean
        metrics["f1_macro"] = f1_score_multiclass_macro
        return metrics
